{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ResNet(\n",
      "  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): Bottleneck(\n",
      "      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): Bottleneck(\n",
      "      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "    )\n",
      "    (2): Bottleneck(\n",
      "      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): Bottleneck(\n",
      "      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): Bottleneck(\n",
      "      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "    )\n",
      "    (2): Bottleneck(\n",
      "      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "    )\n",
      "    (3): Bottleneck(\n",
      "      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): Bottleneck(\n",
      "      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): Bottleneck(\n",
      "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "    )\n",
      "    (2): Bottleneck(\n",
      "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "    )\n",
      "    (3): Bottleneck(\n",
      "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "    )\n",
      "    (4): Bottleneck(\n",
      "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "    )\n",
      "    (5): Bottleneck(\n",
      "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): Bottleneck(\n",
      "      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): Bottleneck(\n",
      "      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "    )\n",
      "    (2): Bottleneck(\n",
      "      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
      "      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=2048, out_features=1000, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Load an ImageNet-pretrained ResNet-50 to use as the transfer-learning backbone.\n",
    "from torchvision import models\n",
    "transfer_model = models.resnet50(pretrained=True)\n",
    "print(transfer_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze every pretrained weight except the BatchNorm parameters,\n",
    "# which are left trainable so normalization can adapt to the new data.\n",
    "for layer_name, weight in transfer_model.named_parameters():\n",
    "    if \"bn\" not in layer_name:\n",
    "        weight.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Swap ResNet's 1000-way ImageNet head for a small two-class classifier.\n",
    "import torch.nn as nn\n",
    "transfer_model.fc = nn.Sequential(\n",
    "    nn.Linear(transfer_model.fc.in_features, 500),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(),\n",
    "    nn.Linear(500, 2),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data augmentation / preprocessing pipeline for 64x64 RGB inputs.\n",
    "# Import under an alias: the original bound the Compose result to the name\n",
    "# `transforms`, shadowing the module and making this cell fail when re-run.\n",
    "# The result is still bound to `transforms` for the downstream dataset cells.\n",
    "import torchvision.transforms as T\n",
    "transforms = T.Compose([T.Resize((64, 64)),\n",
    "                        # NOTE(review): all-zero jitter ranges make this a no-op;\n",
    "                        # raise the values to actually perturb the colors.\n",
    "                        T.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0),\n",
    "                        # Random flips; p is the probability of flipping.\n",
    "                        T.RandomHorizontalFlip(p=0.5),\n",
    "                        T.RandomVerticalFlip(p=0.0),\n",
    "                        # Convert to grayscale with 10% probability.\n",
    "                        T.RandomGrayscale(p=0.1),\n",
    "                        # Rotation up to 45 degrees plus a 0.4-degree shear.\n",
    "                        T.RandomAffine(45, translate=None, scale=None, shear=0.4, resample=False, fillcolor=0),\n",
    "                        T.ToTensor(),\n",
    "                        # ImageNet statistics, matching the pretrained backbone.\n",
    "                        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "\n",
    "# Each split follows the ImageFolder layout: one sub-directory per class.\n",
    "train_data_path = \"./train/\"\n",
    "val_data_path = \"./val/\"\n",
    "test_data_path = \"./test/\"\n",
    "\n",
    "train_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=transforms)\n",
    "val_data = torchvision.datasets.ImageFolder(root=val_data_path, transform=transforms)\n",
    "test_data = torchvision.datasets.ImageFolder(root=test_data_path, transform=transforms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.utils.data as data\n",
    "\n",
    "batch_size = 64\n",
    "\n",
    "def make_loader(dataset, shuffle=False):\n",
    "    # Thin wrapper so all three loaders share the same batch size.\n",
    "    return data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "# Only the training split is shuffled; evaluation order is irrelevant.\n",
    "train_data_loader = make_loader(train_data, shuffle=True)\n",
    "val_data_loader = make_loader(val_data)\n",
    "test_data_loader = make_loader(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.optim as optim\n",
    "# Hand Adam only the parameters left trainable (BatchNorm + the new fc head);\n",
    "# frozen parameters never receive gradients, so tracking them wastes\n",
    "# optimizer state without changing the training result.\n",
    "optimizer = optim.Adam([p for p in transfer_model.parameters() if p.requires_grad], lr=0.001)\n",
    "loss_fn = torch.nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device=\"cpu\"):\n",
    "    \"\"\"Train `model` for `epochs` epochs, printing per-epoch metrics.\n",
    "\n",
    "    Training loss is the dataset-size-weighted average of per-batch losses;\n",
    "    validation reports the same average loss plus top-1 accuracy.\n",
    "    \"\"\"\n",
    "    for epoch in range(epochs):\n",
    "        training_loss = 0.0\n",
    "        valid_loss = 0.0\n",
    "        model.train()   # training mode: dropout active, BN stats updated\n",
    "        for batch in train_loader:\n",
    "            optimizer.zero_grad()\n",
    "            inputs, targets = batch\n",
    "            inputs = inputs.to(device)\n",
    "            targets = targets.to(device)\n",
    "            output = model(inputs)\n",
    "            loss = loss_fn(output, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Weight each batch's loss by its size so the epoch mean is exact\n",
    "            # even when the last batch is smaller. (.item() replaces the\n",
    "            # deprecated .data.item().)\n",
    "            training_loss += loss.item() * inputs.size(0)\n",
    "        training_loss /= len(train_loader.dataset)\n",
    "        \n",
    "        model.eval()   # evaluation mode for validation\n",
    "        num_correct = 0\n",
    "        num_examples = 0\n",
    "        with torch.no_grad():   # validation needs no autograd graphs: saves memory\n",
    "            for batch in val_loader:\n",
    "                inputs, targets = batch\n",
    "                inputs = inputs.to(device)\n",
    "                targets = targets.to(device)\n",
    "                output = model(inputs)\n",
    "                loss = loss_fn(output, targets)\n",
    "                valid_loss += loss.item() * inputs.size(0)\n",
    "                # argmax of the logits equals argmax of their softmax, so the\n",
    "                # explicit F.softmax (and the dependency on a later import of\n",
    "                # torch.nn.functional) is unnecessary.\n",
    "                correct = torch.eq(output.argmax(dim=1), targets).view(-1)\n",
    "                num_correct += torch.sum(correct).item()\n",
    "                num_examples += correct.shape[0]\n",
    "        valid_loss /= len(val_loader.dataset)\n",
    "        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss, valid_loss, num_correct / num_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0, Training Loss: 0.70, Validation Loss: 0.38, accuracy = 0.85\n",
      "Epoch: 1, Training Loss: 0.47, Validation Loss: 0.24, accuracy = 0.94\n",
      "Epoch: 2, Training Loss: 0.34, Validation Loss: 0.20, accuracy = 0.90\n",
      "Epoch: 3, Training Loss: 0.26, Validation Loss: 0.20, accuracy = 0.90\n",
      "Epoch: 4, Training Loss: 0.25, Validation Loss: 0.27, accuracy = 0.89\n",
      "Epoch: 5, Training Loss: 0.20, Validation Loss: 0.18, accuracy = 0.88\n",
      "Epoch: 6, Training Loss: 0.19, Validation Loss: 0.15, accuracy = 0.95\n",
      "Epoch: 7, Training Loss: 0.20, Validation Loss: 0.20, accuracy = 0.92\n",
      "Epoch: 8, Training Loss: 0.18, Validation Loss: 0.21, accuracy = 0.91\n",
      "Epoch: 9, Training Loss: 0.15, Validation Loss: 0.25, accuracy = 0.91\n",
      "Epoch: 10, Training Loss: 0.18, Validation Loss: 0.14, accuracy = 0.93\n",
      "Epoch: 11, Training Loss: 0.11, Validation Loss: 0.24, accuracy = 0.86\n",
      "Epoch: 12, Training Loss: 0.13, Validation Loss: 0.09, accuracy = 0.95\n",
      "Epoch: 13, Training Loss: 0.14, Validation Loss: 0.21, accuracy = 0.88\n",
      "Epoch: 14, Training Loss: 0.11, Validation Loss: 0.23, accuracy = 0.86\n",
      "Epoch: 15, Training Loss: 0.13, Validation Loss: 0.22, accuracy = 0.94\n",
      "Epoch: 16, Training Loss: 0.11, Validation Loss: 0.18, accuracy = 0.91\n",
      "Epoch: 17, Training Loss: 0.10, Validation Loss: 0.20, accuracy = 0.95\n",
      "Epoch: 18, Training Loss: 0.10, Validation Loss: 0.24, accuracy = 0.89\n",
      "Epoch: 19, Training Loss: 0.09, Validation Loss: 0.21, accuracy = 0.94\n"
     ]
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "# Fine-tune the adapted ResNet-50 on the default device (CPU).\n",
    "train(transfer_model, optimizer, loss_fn, train_data_loader, val_data_loader, epochs=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model accuracy on test data: 92.59%\n"
     ]
    }
   ],
   "source": [
    "# Measure top-1 accuracy on the held-out test set.\n",
    "transfer_model.eval()   # ensure dropout/BN are in inference mode\n",
    "corrects = 0\n",
    "total = 0\n",
    "with torch.no_grad():   # inference only: skip gradient bookkeeping\n",
    "    for batch in test_data_loader:\n",
    "        inputs, targets = batch\n",
    "        prediction = transfer_model(inputs)\n",
    "        # argmax of the logits equals argmax of their softmax.\n",
    "        correct = torch.eq(prediction.argmax(dim=1), targets).view(-1)\n",
    "        corrects += torch.sum(correct).item()\n",
    "        total += correct.size()[0]\n",
    "print(\"Model accuracy on test data: {}%\".format(round(corrects/total*100,2)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Finding That Learning Rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh, untrained AlexNet with a 2-class head, used for the LR range test below.\n",
    "# NOTE: this rebinds the `loss_fn` and `optimizer` globals used earlier.\n",
    "alexnet = models.alexnet(num_classes=2)\n",
    "loss_fn = torch.nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(alexnet.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "def find_lr(model, loss_fn, optimizer, train_loader, init_value=1e-4, final_value=1):\n",
    "    \"\"\"Learning-rate range test: sweep the LR geometrically from\n",
    "    `init_value` to `final_value` over one pass of `train_loader`,\n",
    "    recording the loss at each step.\n",
    "\n",
    "    Returns (log_lrs, losses) for plotting loss vs. log10(lr).\n",
    "    Aborts early once the loss exceeds 4x the best loss seen.\n",
    "    \"\"\"\n",
    "    number_in_epoch = len(train_loader) - 1\n",
    "    update_step = (final_value / init_value) ** (1 / number_in_epoch)\n",
    "    lr = init_value\n",
    "    optimizer.param_groups[0][\"lr\"] = lr\n",
    "    best_loss = 0.0\n",
    "    batch_num = 0\n",
    "    losses = []\n",
    "    log_lrs = []\n",
    "    for data in train_loader:\n",
    "        batch_num += 1\n",
    "        inputs, labels = data\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        loss = loss_fn(outputs, labels)\n",
    "        # Work with the Python float from here on: comparing/storing the\n",
    "        # tensor itself would keep every batch's autograd graph alive,\n",
    "        # leaking memory over the sweep.\n",
    "        loss_value = loss.item()\n",
    "        # Crash out if loss explodes\n",
    "        if batch_num > 1 and loss_value > 4 * best_loss:\n",
    "            return log_lrs, losses\n",
    "        # Record the best loss\n",
    "        if loss_value < best_loss or batch_num == 1:\n",
    "            best_loss = loss_value\n",
    "        # Store scalar values, not tensors\n",
    "        losses.append(loss_value)\n",
    "        log_lrs.append(math.log10(lr))\n",
    "        # Do the backward pass and optimize\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # Update the lr for the next step and store\n",
    "        lr *= update_step\n",
    "        optimizer.param_groups[0][\"lr\"] = lr\n",
    "    return log_lrs, losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x1d88a4b40c8>]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3dfXRU933n8fd3Rk9I6HEkxIN4kMaAAT+AAwQpj7ZjG3y2dpJtYmizTdqeONvGie1NuutsEtd1mp7Tpq3jpO4m7m5ONumusePYjtsgMIndJFtkGzDYDGBAEk8CDQgJCQmhp5nv/jF3yCAkNJJGc2dG39c5c5i593dnvrqMPnP1vXfuFVXFGGNM5vK4XYAxxpipZUFvjDEZzoLeGGMynAW9McZkOAt6Y4zJcFluFzBceXm5Llq0yO0yjDEmrezevfucqlaMNC/lgn7RokXs2rXL7TKMMSatiMjx0eZZ68YYYzKcBb0xxmQ4C3pjjMlwFvTGGJPhLOiNMSbDWdAbY0yGs6A3xpgMZ0FvzDQTCivPvHmC3oEht0sxSWJBb8w089q7Z/nKC/t4fneL26WYJLGgN2aaqQ8EAWhoane5EpMsFvTGTCODoTC/OHgGgIbmdsJhu8LcdGBBb8w08npzO12XBtlww2w6ewc5GLzgdkkmCSzojZlG6gNB8nO8/Lf11wPWvpku4gp6EVkvIodEpFFEHhlh/gIReU1E9ojIOyJytzP9DhHZLSL7nH9vS/QPYIyJTyisvLI/yK3Xz2JReQHV5QUW9NPEmEEvIl7gKWADsBzYJCLLhw37GvCcqq4CNgL/6Ew/B/yOqt4IfBr4caIKN8aMz65jHZzrGWDDDbMBqPX7eONoB0OhsMuVmakWzxb9WqBRVZtVdQDYDNw7bIwCRc79YuA0gKruUdXTzvT9QJ6I5E6+bGPMeNUHguRmebh16SwAamt89PQPse9Ul8uVmakWT9DPA07GPG5xpsV6DPiUiLQAW4AvjPA8/xHYo6r9w2eIyP0isktEdrW1tcVVuDEmfuGwsm1/kA8uqaAgN3K9oXU1PiBy9I3JbPEEvYwwbfgxWZuAH6pqFXA38GMRufzcIrIC+GvgcyO9gKo+raqrVXV1RcWIV8IyxkzC2y2dtHb1XW7bAFQU5rK0stD69NNAPEHfAsyPeVyF05qJ8cfAcwCq2gDkAeUAIlIFvAj8gao2TbZgY8z4bQ0EyfYKty+rvGJ6rd/HzmMd9A+FXKrMJEM8Qb8TWCwi1SKSQ2Rn68vDxpwAbgcQkWVEgr5NREqAnwNfUdV/T1zZxph4qSr1gSB1/nKKZ2RfMa/W76NvMMzbJ61Pn8nGDHpVHQIeALYBB4kcXbNfRB4XkXucYV8CPisibwPPAJ9RVXWWuw74uojsdW6zpuQnMcaM6EDrBU509F7RtolaV+1DBHY0nXOhMpMsWfEMUtUtRHayxk57NOb+AeB9Iyz3l8BfTrJGY8wkbA0E8QjcsbzyqnnF+dncMLeYHU3tPPQRF4ozSWHfjDUmw9UHgry32odv5shHNtf6few90cmlAevTZyoLemMyWOPZbhrP9rDhxqvbNlG1fh8DoTC7j59PYmUmmSzojclg9fsipyS+a8XoQb9mURlZHrE+fQazoDcmg9UHgrxnYSmVRXmjjpmZm8VNVcX2xakMZkFvTIY60d7LgdYLIx5tM1ydv5x3Wrro7htMQmUm2SzojclQ9YFW4Nptm6g6v49QWNl5rGOqyzIusKA3JkPVB4LcOK+Y+WX5Y469ZWEpOV4POxqtfZOJLOiNyUCtXZfYe7KT9XG0bQDysr3csrDE+vQZyoLemAy01bkAeDz9+ag6fzkHWi9w/uLAVJVlXGJBb0wGqg8EWVpZSE3FzLiXqfP7UIU3jtpWfaaxoDcmw7R197PzWEfcbZuom6pKmJHttdMWZyALemMyzCsHgqhyzW/DjiQn
y8Oa6jJ2WNBnHAt6YzLM1kCQ6vICllYWjnvZOr+PI2d7ONvdNwWVGbdY0BuTQTp7B2hoamf9DbMRGenicNdW61xe8PVmO54+k1jQG5NBth84w1BYx3W0TawVc4sozMuiwc57k1Es6I3JIFsDQeaVzODGecUTWj7L6+G91T7r02cYC3pjMkR33yC/OXJuwm2bqFq/j+PtvZzqvJTA6oybLOiNyRCvvnuWgVB4wm2bqDp/pE9vh1lmDgt6YzLE1kCQWYW53LKgdFLPs7SykLKCHDs/fQaJK+hFZL2IHBKRRhF5ZIT5C0TkNRHZIyLviMjdMfO+4ix3SETuSmTxxpiISwMh/u1QG3etmI3HM/G2DYDHI6yrKaOhqR1VTVCFxk1jBr2IeIGngA3AcmCTiCwfNuxrwHOqugrYCPyjs+xy5/EKYD3wj87zGWMS6FeHz3JpMDTub8OOptZfTmtXH8fbexPyfMZd8WzRrwUaVbVZVQeAzcC9w8YoUOTcLwZOO/fvBTarar+qHgUaneczxiRQfSBIaX42760uS8jzRfv0dvRNZogn6OcBJ2MetzjTYj0GfEpEWoAtwBfGsawxZhL6h0K8evAsdyyvJMubmN1uNeUFVBblWp8+Q8Tzrhip4Te8cbcJ+KGqVgF3Az8WEU+cyyIi94vILhHZ1dbWFkdJxpiof288R3f/EBtumJOw5xQRamt8vN5sffpMEE/QtwDzYx5X8dvWTNQfA88BqGoDkAeUx7ksqvq0qq5W1dUVFRXxV2+MoX5fkMLcLOqu8yX0eev85ZzrGeDI2Z6EPq9JvniCfiewWESqRSSHyM7Vl4eNOQHcDiAiy4gEfZszbqOI5IpINbAYeDNRxRsz3Q2Gwmw/eIbbl80iNyuxxznURvv0jda+SXdjBr2qDgEPANuAg0SOrtkvIo+LyD3OsC8BnxWRt4FngM9oxH4iW/oHgK3A51U1NBU/iDHT0RvNHXT2DrI+gW2bqPll+VSVzrAdshkgK55BqrqFyE7W2GmPxtw/ALxvlGW/CXxzEjUaY0ZRH2hlRraXDy2ZmpZnnd/Htv1nCIUV7ySPzzfusW/GGpOmQmFl2/4z3Hp9BTNypubrKXX+crouDXKw9cKUPL9JDgt6Y9LU7uPnOdfTPyVtm6jLfXo7zDKtWdAbk6bqA63kZHm47fpZU/YalUV51FQU2AnO0pwFvTFpSFXZFgjywcXlzMyNa1fbhNX5fbx5tIPBUHhKX8dMHQt6Y9LQ2y1dnO7qm9K2TVSdv5yLAyHeaema8tcyU8OC3pg0VB9oJcsj3LGscspfa93l68ha+yZdWdAbk2ZUla2BILV+H8X52VP+emUFOVw/u9B2yKYxC3pj0szB1m6Ot/cm9Nw2Y6nzl7Pr2Hn6h+z7junIgt6YNLM10IpH4M4VU9+2iar1++gfCrPnRGfSXtMkjgW9MWmmPhBkzaIyymfmJu0111aX4RE7P326sqA3Jo00nu3hyNmeSV8AfLyKZ2Rz47xiGqxPn5Ys6I1JI1sDrQBJOaxyuHV+H3tPdtI7MJT01zaTY0FvTBqpDwRZtaCE2cV5SX/tOn85gyFl17HzSX9tMzkW9MakiRPtvew/fSHpbZuoNYtKyfKI9enTkAW9MWli6/5I2yaZh1XGys/JYtWCEuvTpyELemPSRH0gyIq5Rcwvy3ethtoaH/tOdXGhb9C1Gsz4WdAbkwZauy6x50Sna22bqFp/OWGFN5s7XK3DjI8FvTFpYFsgCLhztE2sVQtKyM3yWJ8+zVjQG5MG6gNBFs+ayXWzZrpaR162l/csLKXBTnCWVizojUlx53r62Xmsw/W2TVSd38fB1gt0XBxwuxQTp7iCXkTWi8ghEWkUkUdGmP+EiOx1bodFpDNm3t+IyH4ROSgi3xERu8KwMePwyv4zhNX9tk1Urb8csNMWp5Mxg15EvMBTwAZgObBJRJbHjlHVh1V1paquBL4LvOAsWwe8D7gJuAFYA3wooT+BMRmu
PtDKQl8+y+YUul0KADdVFZOf47XTFqeReLbo1wKNqtqsqgPAZuDea4zfBDzj3FcgD8gBcoFs4MzEyzVmeunqHaShqZ31N8wmVf4YzvZ6WFtdZteRTSPxBP084GTM4xZn2lVEZCFQDbwKoKoNwGtAq3PbpqoHR1jufhHZJSK72traxvcTGJPBth88w1BYXfuS1Gjq/D6a2i5y5kKf26WYOMQT9CNtRugoYzcCz6tqCEBErgOWAVVEPhxuE5EPXvVkqk+r6mpVXV1RURFf5cZMA1sDrcwtzuPmqmK3S7lCbU2kT29b9ekhnqBvAebHPK4CTo8ydiO/bdsAfAx4XVV7VLUHqAfWTaRQY6abnv4hfn3kHHelUNsmavncIorysizo00Q8Qb8TWCwi1SKSQyTMXx4+SESWAqVAQ8zkE8CHRCRLRLKJ7Ii9qnVjjLnaq++eZWAonHJtGwCvR1hX42NHs+2QTQdjBr2qDgEPANuIhPRzqrpfRB4XkXtihm4CNqtqbFvneaAJ2Ae8Dbytqv+SsOqNyWBbA62Uz8zlPQtL3S5lRHV+Hyc7LnGyo9ftUswYsuIZpKpbgC3Dpj067PFjIywXAj43ifqMmZYuDYR47d02Pn7LPLye1GrbREWPp29obnf1RGtmbPbNWGNS0K8Ot3FpMJSSbZuoJZUz8RXkWJ8+DVjQG5OCtgZaKcnP5r01ZW6XMioRodbvY0fTOa7s2JpUY0FvTIrpHwrxy4NnuWNZJdne1P4VrfX7OHOhn6PnLrpdirmG1H4XGTMN7Whsp7t/iA03psZJzK6lzunT22mLU5sFvTEppj7QSmFuFu+7rtztUsa0yJfPnOI869OnOAt6Y1LIUCjM9gNnuG3ZLHKzvG6XMyYRobbGR0NzO+Gw9elTlQW9MSnkzaMdnO8dTJlzz8ej1u+j4+IAh892u12KGYUFvTEppD4QZEa2lw8tmeV2KXGr9fuAyL4Fk5os6I1JEeGwsm1/kA8vrWBGTuq3baKqSvNZ6Mu3HbIpzILemBTx1onznO3uZ30atW2iamt8vHG0nZD16VOSBb0xKaI+ECTH6+G269OnbRNV6/fR3TfE/tNdbpdiRmBBb0wKUFW2BoJ8YHE5hXnZbpczbpf79Na+SUkW9MakgH2nujjVeSkt2zYAswrzuG7WTDuePkVZ0BuTAuoDQbI8wh3LK90uZcLq/D52HutgYCjsdilmGAt6Y1wWbdvU+n2U5Oe4Xc6E1fl99A6EeKel0+1SzDAW9Ma47NCZbo6eu5i2bZuo91b7ELE+fSqyoDfGZfX7gojAncvTO+hLC3JYNrvI+vQpyILeGJdtDQRZs6iMisJct0uZtDq/j90nztM3GHK7FBPDgt4YFzW39XDoTHdandvmWuqu8zEwFOat4+fdLsXEsKA3xkX1gSBA2vfno9YsKsPrERqarX2TSuIKehFZLyKHRKRRRB4ZYf4TIrLXuR0Wkc6YeQtE5BUROSgiB0RkUeLKNya9bQ0EWTm/hDnFM9wuJSEK87K5cV6x7ZBNMWMGvYh4gaeADcByYJOILI8do6oPq+pKVV0JfBd4IWb2j4BvqeoyYC1wNlHFG5POTnb0su9UV8a0baLq/D7ePtnJxf4ht0sxjni26NcCjararKoDwGbg3muM3wQ8A+B8IGSp6nYAVe1R1d5J1mxMRti2P7PaNlG1fh9DYWXnsQ63SzGOeIJ+HnAy5nGLM+0qIrIQqAZedSYtATpF5AUR2SMi33L+Qhi+3P0isktEdrW1tY3vJzAmTdUHgiybU8RCX4HbpSTU6oVlZHvFDrNMIfEEvYwwbbRzkW4EnlfV6LFVWcAHgC8Da4Aa4DNXPZnq06q6WlVXV1RUxFGSMentzIU+dh8/n3FtG4AZOV5WLSi1Pn0KiSfoW4D5MY+rgNOjjN2I07aJWXaP0/YZAl4CbplIocZkkmjbJhODHiLnpw+c7qKrd9DtUgzxBf1OYLGIVItIDpEwf3n4IBFZCpQC
DcOWLRWR6Gb6bcCByZVsTPqr3xfEX1HA4spCt0uZEnV+H6rwxlHbqk8FYwa9syX+ALANOAg8p6r7ReRxEbknZugmYLOqasyyISJtm1+KyD4ibaB/SuQPYEy6ae/p542j7Wy4YY7bpUyZlQtKyMv2WPsmRWTFM0hVtwBbhk17dNjjx0ZZdjtw0wTrMybjbD9whrBm3tE2sXKzvKxeWGY7ZFOEfTPWmCSrDwSZXzaDFXOL3C5lStX6fRw60825nn63S5n2LOiNSaKuS4PsaDrHhhvmIDLSAW2Zo865vODrdjoE11nQG5NEvzx4hsGQZnTbJurGecXMzM2yPn0KsKA3JonqA0FmF+WxsqrE7VKmXJbXw9rqMl63oHedBb0xSXKxf4hfH25j/Q2z8Xgyu20TVef30XzuIq1dl9wuZVqzoDcmSV47dJb+ofC0aNtE1Tp9ejv6xl0W9MYkSX0gSPnMHNYsKnO7lKRZNruIkvxsC3qXWdAbkwR9gyFee/csdyyfjXeatG0APB5hXbWPHU3txHyX0iSZBb0xSfDrw230DoQy9tw211J3nY9TnZc42WF9erdY0BuTBFsDQYpnZF/uWU8ntTWRn3lH0zmXK5m+LOiNmWIDQ2G2HzzDR5ZVku2dfr9y182aSfnMXLuOrIum37vOmCTb0XSO7r6hadm2ARAR6vzWp3eTBb0xU2xrIEhBjpf3Ly53uxTX1Pl9tHX309TW43Yp05IFvTFTaCgU5pUDZ7htWSV52VddRXPasOPp3WVBb8wUevNYBx0XB6Zt2yZqQVk+80pm2HlvXGJBb8wU2hoIkpft4cNLp/e1kEWEWr+PhuZ2wmHr0yebBb0xUyQcVrYGgnxoSQX5OXFd4yej1db46Owd5N1gt9ulTDsW9MZMkT0nz3O2uz+jLxk4HtE+vR1Pn3wW9MZMkfp9QbK9wm3LZrldSkqYWzKD6vIC2yHrAgt6Y6aAqlIfCPL+68opyst2u5yUsa7GxxtHOxgKhd0uZVqJK+hFZL2IHBKRRhF5ZIT5T4jIXud2WEQ6h80vEpFTIvIPiSrcmFQWOHWBU52XrG0zTJ3fR0//EIHTF9wuZVoZcw+RiHiBp4A7gBZgp4i8rKoHomNU9eGY8V8AVg17mm8Av0pIxcakgfpAK16PcMfySrdLSSnrYs57s3J+5l9lK1XEs0W/FmhU1WZVHQA2A/deY/wm4JnoAxF5D1AJvDKZQo1JF6qRo23W1ZRRWpDjdjkppaIwl6WVhdanT7J4gn4ecDLmcYsz7SoishCoBl51HnuAvwP+7FovICL3i8guEdnV1tYWT93GpKzDZ3poPneR9da2GVGt38fOYx0MDFmfPlniCfqRrpIw2jceNgLPq2rIefynwBZVPTnK+MiTqT6tqqtVdXVFxfT+YolJf/WBVkTgrhXWthlJrd9H32CYvSc7xx5sEiKeb3G0APNjHlcBp0cZuxH4fMzjWuADIvKnwEwgR0R6VPWqHbrGZIqtgSCrF5YyqzDP7VJS0rpqHyKRPv3a6ulzWUU3xbNFvxNYLCLVIpJDJMxfHj5IRJYCpUBDdJqq/r6qLlDVRcCXgR9ZyJtMdvTcRd4Ndlvb5hqK87NZMbfI+vRJNGbQq+oQ8ACwDTgIPKeq+0XkcRG5J2boJmCz2gmnzTRWH2gFYP00P4nZWOr85ew50cmlgdDYg82kxXUCDlXdAmwZNu3RYY8fG+M5fgj8cFzVGZNmtgaC3FxVzLySGW6XktJq/T6e/nUzu4+fn9bn6U8W+2asMQnScr6Xd1q6rG0ThzWLyvB6hIZmO+9NMljQG5MgWwNBgGl/7vl4zMzN4uaqYjs/fZJY0BuTIFsDQa6fXcii8gK3S0kLdf5y3mnportv0O1SMp4FvTEJcPZCH7tPnLdz24xDnd9HKKzsPNbhdikZz4LemATYtj+IKmy40do28bplYSk5Xo8dZpkEFvTGJEB9IEhNRQGLZ810u5S0kZft5ZaFJdanTwIL
emMmqePiAG8c7WDDDbMRGemMIWY0df5yDrReoLN3wO1SMpoFvTGTtP1AkFBYrT8/AbV+H6rwerP16aeSBb0xk1QfCFJVOoMVc4vcLiXt3FxVwoxsLw12HdkpZUFvzCR0XRrk3xvPWdtmgnKyPKypLrM+/RSzoDdmEl599wyDIbVvw05CbY2PI2d7aOvud7uUjGVBb8wk1O8LUlmUyyq7LN6E1fkjlxdsaLat+qliQW/MBF3sH+JXh9tYv2I2Ho+1bSZqxdwiCvOyrE8/hSzojZmgfzvURv9Q2No2k5Tl9fBe69NPKQt6YyaoPtCKryDHrpKUALX+co6393Kq85LbpWQkC3pjJqBvMMRr757lzhWVeK1tM2mX+/S2VT8lLOiNmYDfHDnHxYGQtW0SZGllIWUFOeywPv2UsKA3ZgLqA60U5WVRW+Nzu5SM4PEI62rKeL2pHbsaaeJZ0BszTgNDYX5x4AwfWV5JTpb9CiVKrb+c0119HG/vdbuUjBPXu1RE1ovIIRFpFJFHRpj/hIjsdW6HRaTTmb5SRBpEZL+IvCMi9yX6BzAm2Rqa27nQN2TntkmwaJ/ejr5JvDGDXkS8wFPABmA5sElElseOUdWHVXWlqq4Evgu84MzqBf5AVVcA64Fvi4h9s8Skta2BVgpyvHzALmqdUDXlBcwqzLUvTk2BeLbo1wKNqtqsqgPAZuDea4zfBDwDoKqHVfWIc/80cBaomFzJxrgnFFZe2X+GW6+fRV621+1yMoqIUOf30dB0zvr0CRZP0M8DTsY8bnGmXUVEFgLVwKsjzFsL5ABNI8y7X0R2iciutra2eOo2xhVvHu2g/eKAtW2mSJ2/nHM9Axw52+N2KRklnqAf6SDh0T5uNwLPq2roiicQmQP8GPhDVQ1f9WSqT6vqalVdXVFhG/wmdW3bHyQ3y8OHl9r7dCrURvv0jXaYZSLFE/QtwPyYx1XA6VHGbsRp20SJSBHwc+Brqvr6RIo0JhWEw8rWQJAPLqmgIDfL7XIy0vyyfKpKZ1ifPsHiCfqdwGIRqRaRHCJh/vLwQSKyFCgFGmKm5QAvAj9S1Z8kpmRj3LG3pZPghT423GAXAJ9KdX4frzd3EApbnz5Rxgx6VR0CHgC2AQeB51R1v4g8LiL3xAzdBGzWK/eifBL4IPCZmMMvVyawfmOSZmsgSLZXuH1ZpdulZLQ6fzldlwY52HrB7VIyRlx/f6rqFmDLsGmPDnv82AjL/TPwz5Ooz5iU8EZzO8/vbqHOX07xjGy3y8lotTHnvblhXrHL1WQG+1qfMdcwGArzt9sOsemfXqcwL4v/fvcyt0vKeJVFedRUFNh5bxLI9igZM4rj7Rf54ua9vH2yk0+8p4o/v2cFM20nbFLU+X28+NYpBkNhsr22PTpZtgaNGUZVeX53C3c/+RuOtvXw1O/dwrc+cbOFfBLV1pRzcSDEvlNdbpeSEeyda0yMrkuDfPXFffzrO62srS7jiftWMq9khttlTTvraiIXc2loaueWBaUuV5P+bIveGMebRzu4+8nfUB8I8md3LeWZz66zkHeJb2Yu188utD59gtgWvZn2BkNhvvPLIzz1WiPzy/L56Z/UsXK+nXvPbbV+H//3jRP0D4XIzbLzCk2GbdGbae14+0U+8b0GvvtqIx+/pYqff/EDFvIpos5fTv9QmD0nOt0uJe3ZFr2ZllSVF946xaM/C+DxCN/dtIrfuXmu22WZGGury/BI5Pz06+xKXpNiQW+mna5Lg3ztpQD/8vZp1i4q4+/vu5mq0ny3yzLDFM/I5sZ5xTQ0nYM7lrhdTlqzoDfTys5jHTy0eS/BC3186Y4l/Omt1+H1jHSCVpMK1vl9/OD/HaV3YIj8HIuribIevZkWhkJh/v6VQ9z3/Qa8HuEn/7mWL9y+2EI+xdX5yxkMKbuOnXe7lLRmH5Em451o7+XBZ/ew50QnH79lHn9xzwoK8+x8NelgzaJSsjzCjqZ2PrjErgEwURb0JqO9uKeF
r7+0HxH4zqZV3GM7XNNKfk4WK+eX2PnpJ8mC3mSkC32DfP2lAD/be5o1i0p54r6VtsM1TdX5ffzDa41c6BukyP4SmxDr0ZuMs+tY5Buu//pOK1+6Ywmb76+1kE9jtf5ywgpvNne4XUrasqA3GWMoFOaJ7Yf55PcbEMF2uGaIVQtKyMnyWPtmEqx1YzLCyY5eHnp2L7uPn+fjq+bxF/faDtdMkZftZfXCUnY0WdBPlAW9SXsv7TnF118KAPDkxpXcu3KeyxWZRKvz+/jbVw7TcXGAsoIct8tJOxb0Jm1d6Bvk0ZcCvLT3NKsXRna4zi+zXnwmil5e8PXmdu6+cY7L1aQfC3qTlnYf7+DBzXtp7erj4Y8s4fO3+smyKxFlrJuqSsjP8dLQZEE/EXH9ZojIehE5JCKNIvLICPOfEJG9zu2wiHTGzPu0iBxxbp9OZPFm+onucP3E9yI7XJ/7XC0PfmSxhXyGy/Z6WFtdZuenn6Axt+hFxAs8BdwBtAA7ReRlVT0QHaOqD8eM/wKwyrlfBvw5sBpQYLezrH2f2Yxb7A7Xj62ax+O2w3VaqfP7+KtDbZy50EdlUZ7b5aSVeDaD1gKNqtqsqgPAZuDea4zfBDzj3L8L2K6qHU64bwfWT6ZgMz39bO8p7n7yNxwOdvPkxpU8cd9KC/lppramHIj06c34xBP084CTMY9bnGlXEZGFQDXw6niWFZH7RWSXiOxqa2uLp24zTXT3DfLws3t5cPNelswuZMuDH7Cjaqap5XOLKMrLYkejBf14xbMzdqRvm+goYzcCz6tqaDzLqurTwNMAq1evHu25zTSz+/h5Hnp2D6fOX+KhjyzmgVuvs178NOb1COtqfOxotj79eMXzW9MCzI95XAWcHmXsRn7bthnvssYAkR2uT/7iCJ/8fgOqkW+4PvSRJRbyhlq/j5MdlzjZ0et2KWklni36ncBiEakGThEJ898bPkhElgKlQEPM5G3AX4lIqfP4TuArk6rYZLSTHb08/Oxedh0/z0dXzuXxj95gJ7Iyl9X5I336huZ2+87EOIwZ9Ko6JCIPEAltL/ADVd0vIo8Du1T1ZWfoJmCzqmrMsh0i8g0iHxYAj6uqnZnIjOhne0/xtRcDKPDt+1by0VXWiwMFBUUAAAstSURBVDdXWlI5E19BDg1N7Xxy9fyxFzBAnF+YUtUtwJZh0x4d9vixUZb9AfCDCdZnpoHuvkH+/Gf7eWHPKW5ZUMKTG1fZ1poZkYiwzu+joakdVUXETlgXj4z5ZmworBxvv0iWx4PXK3hF8HquvGVF74vgsTMapoS3Tpznoc17aTnfy4O3L+YLt9kOV3NtdX4fP3+nlaPnLlJTMdPtctJCxgR9Z+8At/3dr+IeL8LlD4MsTyT4s674UPDg8RD5VyL/Dv/g8A5bZviHy6jPK0KWVyL/jvKcOVkefAW5lM/MoaIwl/LCXApzszJmCyYUVp56rZEnf3mE2UV5PPe5WlYvKnO7LJMGon36HU3tFvRxypigL8jN4tv3rSQUVkJhZSishFQJhcKEFELhMENhJezMi/4bOz6szrSQs+wV88OEwpHnufx8IWUwFObS4NXPF3JefygU87wj3IbCYcJxHlCam+WJhP7M3Cv+rSjMpcL5QKiYmUd5YQ75Oan7X9tyPrLDdeex89y7ci7fsB2uZhwW+fKZU5xHQ1M7n1q30O1y0kLqpsE45WV703bnnepvPxiiHzr9g2HaL/ZzrnuAtp4+2rr7Odcz4Pzbz8mOXt46fp6O3gF0hA+Kghwv5YW5VAz7UIj9cCifmUP5zFzysr1J+1lffvs0X31xH6rwxH0387FVVUl7bZMZRITaGh+/OtxGOKzWho1DxgR9OhOnlXPFf0YeVBTmwuxrLzsUCtNxcYC2nn7auvuv+EBo6+nnXHc/R8720NDcTmfv4IjPUZiX5fw1kHvFh8PwDwnfzByyJ9g/7+kf4s9/tp+fvtXCqgUlPHnfKhb4bIermZhav48X
9pzi8Nlurp9d5HY5ExIOK2e6+zjR3suJjl5OdvRSmJfNZz9Yk/DXsqBPc1leD7OK8pgVx0me+odCtPcMcK6n//JfBsM/HA6evsCvu/vp7h8a8TlK87Ov/Msg5sMh9kOirCDn8iX89pw4z0PP7uVkRy9fvH0xX7QdrmaSouen39HYntJB39M/xMmO3wb5iZhby/lLDAyFL4/1CLzvunILejM5uVle5pbMYG7JjDHH9g2GrviroG2ED4c9Jzpp6+7n0mDoquU9AmXOzuQjZ3uYXZTHs5+rZY3tcDUJUFWaz4KyfBqa2/mj91e7VkcorAQvRLbKhwf5yY5e2i8OXDG+MC+Lhb58rp9dyB3LK1lQln/5NrdkxoT/Yh6LBb0ZUV62l/ll+XEdz36xf+iqD4XYD4c1i8r48l1LKZ5hO1xN4tT5ffx8XyuhsE7pBeC7+wZH2CKPnIah5Xwvg6Hf7iTzeoR5JTNYUJbPnStmXxHkC8ryKc5353fAgt5MWkFuFgW5WSwqL3C7FDON1Pp9bN55kv2nu7ipqmTCzzMUCtPa1TfiFvmJjl7OD9u3VZKfzYKyfJbPLWL9DVeG+ZzivJRsS1rQG2PSUm1NpE/f0NQ+ZtB3XRocNchPnb/EUMwxzlkeoap0BvPL8rn7xjmXQzz6F246/mVqQW+MSUuzivK4btZMdjRF+vStnX0jBvmJjl66Ll25VV5WkMP8snxuqirhP9w053KQLyjLZ3ZRam6VT4YFvTEmbdX5ffz49eNc//WthGK2yrO9wvzSSHivnF9yRZDPL5sx7a5OZkFvjElb/2ndQi4NhJhdnHc5yBeU5VNZlDelO2jTjQW9MSZtLa4s5FufuNntMlJeZjWijDHGXMWC3hhjMpwFvTHGZDgLemOMyXAW9MYYk+Es6I0xJsNZ0BtjTIazoDfGmAwnOtJ16FwkIm3A8Uk8RTlwLkHlJJLVNT5W1/hYXeOTiXUtVNWKkWakXNBPlojsUtXVbtcxnNU1PlbX+Fhd4zPd6rLWjTHGZDgLemOMyXCZGPRPu13AKKyu8bG6xsfqGp9pVVfG9eiNMcZcKRO36I0xxsSwoDfGmAyX1kEvIl8WERWR8lHmf1pEjji3T8dMf4+I7BORRhH5jogk5FI0IvINEXlHRPaKyCsiMneEMbc686O3PhH5qDPvhyJyNGbeymTV5YwLxbz2yzHTq0XkDWc9PisiOcmqS0RWikiDiOx3xt4XM8/t9ZXs99e3RORdp7YXReSqK2KLyNJh768LIvKQM+8xETkVM+/uZNXljDvmrJe9IrIrZnqZiGx31uN2ESlNVl0iMl9EXhORg8577MGYeW6vr/Uicsh5Hz0SM338v4+qmpY3YD6wjciXq8pHmF8GNDv/ljr3S515bwK1gAD1wIYE1VQUc/+LwPfGGF8GdAD5zuMfAr87BesqrrqAnlGmPwdsdO5/D/iTZNUFLAEWO/fnAq1Aidvry6X3151AlnP/r4G/HmO8FwgS+SINwGPAl6dgfcVVF3BslN/VvwEece4/MtbPlci6gDnALc79QuAwsNzt9eX83zUBNUAO8HZMXeP+fUznLfongP8KjLY3+S5gu6p2qOp5YDuwXkTmEPlFbtDImvoR8NFEFKSqF2IeFlyjtqjfBepVtTcRrz+aCdR1mbM1ehvwvDPpf5PE9aWqh1X1iHP/NHAWGPHbf4kS5/py4/31iqoOOQ9fB6rGWOR2oElVJ/NN86moa7h7ibyvILHvrzHrUtVWVX3Lud8NHATmJeL1J1MXsBZoVNVmVR0ANgP3TvT3MS2DXkTuAU6p6tvXGDYPOBnzuMWZNs+5P3x6omr7poicBH4feHSM4RuBZ4ZN+6bzJ90TIpKb5LryRGSXiLwebScBPqAz5o3p2voSkbVEtm6aYia7tb5ceX/F+CMify1cy0jvrwec9fWDRLVIxlGXAq+IyG4RuT9meqWqtkIkeIFZSa4LABFZBKwC3oiZ7Nb6Gu39NaHfx5QN
ehH5hYgERrjdC3yVsUN0pL6oXmN6IupCVb+qqvOB/wM8cI3nmQPcSKT9FPUV4HpgDZGWwH9Lcl0LNPL1698Dvi0iflJrff0Y+ENVDTuT3Vxfrry/nDFfBYac2kZ7nhzgHuAnMZP/B+AHVhJpgf1dkut6n6reAmwAPi8iH4z39ae4LkRkJvBT4KGYv+jcXF+JfX8luv801Tci4XiWSL/vmLOiTgCzh43bBHw/5vH3nWlzgHdHG5fAOhcCgWvMfxB4+hrzPwz8a7Lrihn3QyKtJSFykqVoT7EW2JbMuoAi4C3gE6myvtx6fwGfBhpw9utcY9y9wCvXmL8onvdBouuKGf8YTv8bOATMce7PAQ4lsy4gm8gG139JlfU1/PeMyEbNVyb6+5iyW/SjUdV9qjpLVRep6iIif7rcoqrBYUO3AXeKSKnzJ9edRFZIK9AtIuucftcfAD9LRG0isjjm4T3Au9cYvolhf1Y7W63RvvhHgUCy6nLWU65zvxx4H3BAI++m14iEPkTeoElbX86W6YvAj1T1J8Pmuba+cOf9tZ7IXy336Nj7dUZ9fzk+RuLW15h1iUiBiBRG7xNZX9HXf5nI+woS+/6Kpy4B/hdwUFX/ftg819YXsBNYLJEjbHKItOFenvDvY6I+ody6EbMnH1gN/M+YeX8ENDq3P4yZvprIf1oT8A843xBOQC0/dZ73HeBfgHmj1LUIOAV4hi3/KrDPeY5/BmYmqy6gznntt51//zhm+RoiR5I0EmkF5Caxrk8Bg8DemNtKt9eXS++vRiJ92+h6+J4zfS6wJWZcPtAOFA9b/sfO+nqHSLjOSVZdznvobee2H/hqzPI+4JfAEeffsiTW9X4irY93Ysbd7fb6ch7fTeQooKZh62vcv492CgRjjMlwade6McYYMz4W9MYYk+Es6I0xJsNZ0BtjTIazoDfGmAxnQW+MMRnOgt4YYzLc/wcEV2hYYyouPAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Sweep learning rates and plot loss vs. log-LR to pick a starting LR.\n",
    "# NOTE(review): the sweep runs on `alexnet`, but later cells apply the found LR\n",
    "# to `transfer_model` (a ResNet) -- confirm this is intended.\n",
    "logs, losses = find_lr(alexnet, loss_fn, optimizer, train_data_loader)\n",
    "plt.plot(logs, losses);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LR chosen by eye from the finder plot above (steepest-descent region of the curve).\n",
    "found_lr = 10**(-2.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Differential Learning Rates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Differential learning rates: deeper pretrained layers get progressively\n",
    "# smaller LRs (layer4: found_lr/3, layer3: found_lr/9) so low-level\n",
    "# pretrained features change less than the layers closest to the head.\n",
    "# NOTE(review): only the layer3/layer4 parameter groups are passed to Adam,\n",
    "# so no other parameters (e.g. the classifier head) are optimized here --\n",
    "# confirm that is intended.\n",
    "optimizer = optim.Adam([\n",
    "    {'params': transfer_model.layer4.parameters(), 'lr': found_lr / 3},\n",
    "    {'params': transfer_model.layer3.parameters(), 'lr': found_lr / 9},\n",
    "], lr=found_lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unfreeze layer3 and layer4 so backprop updates their weights again.\n",
    "unfreeze_layers = [transfer_model.layer3, transfer_model.layer4]\n",
    "params_to_unfreeze = (p for layer in unfreeze_layers for p in layer.parameters())\n",
    "for param in params_to_unfreeze:\n",
    "    param.requires_grad = True"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
